model.py

  • Class DistributedInferenceBaseModel

    • _generate_output

      • 直到生成max new tokens

        • outputs = self.model(
                              input_ids,
                              position_ids=position_ids,
                              past_key_values=past_key_values,
                              use_cache=True,
                              enable_star_attn=True,
                          )  # type: ignore
          
        • 在最后一个rank上更新past_key_values

        • logits获取新的token加入到output中

        • 更新input和position

          • 更新的数据是整个query再次传进去的
          • query只添加到最后一个rank上,在query生成比较长的时候可能会出现负载均衡问题
  • class StarAttentionModel

    • def _tokenize_and_partition_context
      • 将输入padding为可整除并转化为tokens和positions两个tensor
    • def _process_blockwise_context
      • 每个rank做prefill,循环n次,每次计算一个block_size大小的,anchor block选择为rank0的第一个block
      • 返回当前rank的kv cache
    • def __call__
      • 生成长文本的KV Cache
        • 调用_tokenize_and_partition_context获取ctx_ids, position_ids
        • 将ctx_ids拆分为world_size个tensor类型的ctx_ids_blocks,每个tensor的形状是[-1,1,block_size]
        • position_ids同理转化为position_ids_blocks
        • 调用_tokenize_and_partition_context生成当前rank的kv cache
      • 生成Query
        • embedding
        • 调用_generate_output生成结果

modeling_llama.py

results matching ""

    No results matching ""